/**
* Copyright (c) 2009-2012, Regents of the University of Colorado
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package com.googlecode.clearnlp.classification.model;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.classification.vector.SparseFeatureVector;
import com.googlecode.clearnlp.collection.map.ObjectIntHashMap;
import com.googlecode.clearnlp.util.pair.Pair;

/**
 * Abstract model.
 * @since 1.0.0
 * @author Jinho D. Choi ({@code jdchoi77@gmail.com})
 */
abstract public class AbstractModel
{
	protected final Logger LOG = LoggerFactory.getLogger(this.getClass());
	
	/** The total number of labels. */
	protected int      n_labels;
	/** The total number of features (index 0 is reserved for the bias term). */
	protected int      n_features;
	/** The weight vector for all labels (laid out feature-major; see {@link #getWeightIndex(int, int)}). */
	protected double[] d_weights;
	/** The list of all labels, indexed by {@link #getLabelIndex(String)}. */
	protected String[] a_labels;
	/** The map between labels and their 1-based indices (0 means "absent"). */
	protected ObjectIntHashMap<String> m_labels;
	/** The type of a solver algorithm. */
	protected byte i_solver;
	
	/** Constructs an abstract model for training. */
	public AbstractModel()
	{
		n_labels   = 0;
		n_features = 1;		// feature index 0 is the bias, so the count starts at 1
		m_labels   = new ObjectIntHashMap<String>();
	}
	
	/**
	 * Constructs an abstract model for decoding.
	 * @param reader the reader to load the model from.
	 */
	public AbstractModel(ObjectInputStream reader)
	{
		load(reader);
	}
	
	/**
	 * Loads this model from the specific stream.
	 * @param in the reader to load the model from.
	 */
	abstract public void load(ObjectInputStream in);
	
	/**
	 * Saves this model to the specific stream.
	 * @param out the stream to save this model to.
	 */
	abstract public void save(ObjectOutputStream out);
	
	// ========================= INITIALIZATION =========================
	
	/**
	 * Initializes the label array after adding all labels.
	 * Must be called after every label has been registered via {@code addLabel}.
	 * @see StringModel#addLabel(String)
	 */
	public void initLabelArray()
	{
		a_labels = new String[n_labels];
		String label;
		
		for (ObjectCursor<String> cur : m_labels.keys())
		{
			label = cur.value;
			a_labels[getLabelIndex(label)] = label;
		}
	}
	
	/** Initializes the weight vector given the label and feature sizes. */
	public void initWeightVector()
	{
		// binary models store a single shared weight vector; multi-class models store one per label
		d_weights = isBinaryLabel() ? new double[n_features] : new double[n_features * n_labels];
	}
	
	// ========================= GETTER =========================
	
	/** @return the total number of labels in this model. */
	public int getLabelSize()
	{
		return n_labels;
	}
	
	/** @return the total number of features in this model. */
	public int getFeatureSize()
	{
		return n_features;
	}
	
	/**
	 * Returns the index of the specific label.
	 * Returns {@code -1} if the label is not found in this model.
	 * @param label the label to get the index for.
	 * @return the index of the specific label.
	 */
	public int getLabelIndex(String label)
	{
		// m_labels stores 1-based indices, so a missing label (get() == 0) maps to -1
		return m_labels.get(label) - 1;
	}
	
	/**
	 * Returns the index of the weight vector given the label and the feature index.
	 * @param label the 0-based label index.
	 * @param index the feature index.
	 * @return the index of the weight vector given the label and the feature index.
	 */
	protected int getWeightIndex(int label, int index)
	{
		return index * n_labels + label;
	}
	
	/** @return the type of the solver algorithm used by this model. */
	public byte getSolver()
	{
		return i_solver;
	}
	
	/**
	 * @param index the 0-based label index.
	 * @return the label with the specific index.
	 */
	public String getLabel(int index)
	{
		return a_labels[index];
	}
	
	/** @return the array of all labels in this model (not a copy; do not modify). */
	public String[] getLabels()
	{
		return a_labels;
	}
	
	/** @return the raw weight vector of this model (not a copy; do not modify). */
	public double[] getWeights()
	{
		return d_weights;
	}
	
	/**
	 * Returns a copy of the weight vector of the specific label (for multi-classification).
	 * @param label the 0-based label index.
	 * @return a copy of the weight vector of the specific label.
	 */
	public double[] getWeights(int label)
	{
		double[] weights = new double[n_features];
		int i;
		
		for (i=0; i<n_features; i++)
			weights[i] = d_weights[getWeightIndex(label, i)];
		
		return weights;
	}
	
	// ========================= SETTER =========================
	
	/** Sets the type of the solver algorithm. */
	public void setSolver(byte solver)
	{
		i_solver = solver;
	}
	
	/**
	 * Adds the specific label to this model.
	 * Duplicate labels are ignored.
	 * @param label the label to be added.
	 */
	public void addLabel(String label)
	{
		if (!m_labels.containsKey(label))
			m_labels.put(label, ++n_labels);	// 1-based index; see getLabelIndex()
	}
	
	/** Sets the weight vector of this model (stored by reference, not copied). */
	public void setWeights(double[] weights)
	{
		d_weights = weights; 
	}
	
	/**
	 * Copies a weight vector for binary classification.
	 * @param weights the weight vector to be copied. 
	 */
	public void copyWeights(double[] weights)
	{
		System.arraycopy(weights, 0, d_weights, 0, n_features);
	}
	
	/**
	 * Copies a weight vector of the specific label (for multi-classification).
	 * The copy is strided, so {@code System.arraycopy} cannot be used here.
	 * @param weights the weight vector to be copied.
	 * @param label the label of the weight vector.
	 */
	public void copyWeights(double[] weights, int label)
	{
		int i;
		
		for (i=0; i<n_features; i++)
			d_weights[getWeightIndex(label, i)] = weights[i];
	}
	
	// ========================= BOOLEAN =========================
	
	/** @return {@code true} if this model contains only 2 labels. */
	public boolean isBinaryLabel()
	{
		return n_labels == 2;
	}
	
	/**
	 * @param featureIndex the index of the feature.
	 * @return {@code true} if the specific feature index is within the range of this model
	 *         (index 0 is the bias term and is deliberately excluded).
	 */
	public boolean isRange(int featureIndex)
	{
		return 0 < featureIndex && featureIndex < n_features;
	}
	
	// ========================= LOAD/SAVE =========================
	
	/**
	 * Loads the solver type, labels, label map, and weight vector from the stream,
	 * in the order written by {@link #saveDefault(ObjectOutputStream)}, and derives
	 * the label/feature counts from them.
	 * @param in the stream to load this model from.
	 * @throws Exception if an I/O or deserialization error occurs
	 *         ({@code IOException} or {@code ClassNotFoundException} from {@code readObject}).
	 */
	@SuppressWarnings("unchecked")
	protected void loadDefault(ObjectInputStream in) throws Exception
	{
		i_solver   = (Byte)in.readObject();	LOG.info(".");
		a_labels   = (String[])in.readObject();	LOG.info(".");
		m_labels   = (ObjectIntHashMap<String>)in.readObject();	LOG.info(".");
		d_weights  = (double[])in.readObject();	LOG.info(".");
		
		n_labels   = a_labels.length;
		n_features = d_weights.length;
		if (!isBinaryLabel()) n_features /= n_labels;	// multi-class stores n_labels weights per feature
	}
	
	/**
	 * Saves the solver type, labels, label map, and weight vector to the stream,
	 * in the order expected by {@link #loadDefault(ObjectInputStream)}.
	 * @param out the stream to save this model to.
	 * @throws IOException if an I/O error occurs.
	 */
	protected void saveDefault(ObjectOutputStream out) throws IOException
	{
		// Byte.valueOf: the Byte(byte) constructor is deprecated (boxed-primitive constructor)
		out.writeObject(Byte.valueOf(i_solver));	LOG.info(".");
		out.writeObject(a_labels);				LOG.info(".");
		out.writeObject(m_labels);				LOG.info(".");
		out.writeObject(d_weights);				LOG.info(".");
	}
	
	// ========================= SCORES =========================
	
	/**
	 * For binary classification, this method calls {@link #getScoresBinary(SparseFeatureVector)}.
	 * For multi-classification, this method calls {@link #getScoresMulti(SparseFeatureVector)}.
	 * @param x the feature vector.
	 * @return the scores of all labels given the feature vector.
	 */
	public double[] getScores(SparseFeatureVector x)
	{
		return isBinaryLabel() ? getScoresBinary(x) : getScoresMulti(x);
	}

	/**
	 * @param x the feature vector.
	 * @return the scores of all labels given the feature vector:
	 *         {@code {score, -score}} for the first and second labels respectively.
	 */
	private double[] getScoresBinary(SparseFeatureVector x)
	{
		double  score = d_weights[0];	// weight at index 0 is the bias term
		boolean hasWeight = x.hasWeight();	// hoisted: invariant over the loop
		int     i, index, size = x.size();
		
		for (i=0; i<size; i++)
		{
			index = x.getIndex(i);
			
			if (isRange(index))
				score += hasWeight ? d_weights[index] * x.getWeight(i) : d_weights[index];
		}
		
		double[] scores = {score, -score};
		return scores;
	}
	
	/**
	 * @param x the feature vector.
	 * @return the scores of all labels given the feature vector.
	 */
	private double[] getScoresMulti(SparseFeatureVector x)
	{
		// the first n_labels weights (feature index 0) are the per-label bias terms
		double[] scores = Arrays.copyOf(d_weights, n_labels);
		boolean  hasWeight = x.hasWeight();	// hoisted: invariant over the loop
		int      i, index, label, size = x.size();
		double   weight;
		
		for (i=0; i<size; i++)
		{
			index = x.getIndex(i);
			
			if (isRange(index))
			{
				weight = hasWeight ? x.getWeight(i) : 1d;	// x * 1d == x exactly, so this matches the unweighted add
				
				for (label=0; label<n_labels; label++)
					scores[label] += d_weights[getWeightIndex(label, index)] * weight;
			}
		}
		
		return scores;
	}
	
	/**
	 * Returns the best prediction given the feature vector.
	 * NOTE(review): relies on StringPrediction's natural ordering placing the best
	 * score first ({@code Collections.min}) — confirm against StringPrediction#compareTo.
	 * @param x the feature vector.
	 * @return the best prediction given the feature vector.
	 */
	public StringPrediction predictBest(SparseFeatureVector x)
	{
		return Collections.min(getPredictions(x));
	}
	
	/**
	 * Returns the first and second best predictions given the feature vector.
	 * @param x the feature vector.
	 * @return the first and second best predictions given the feature vector.
	 */
	public Pair<StringPrediction,StringPrediction> predictTwo(SparseFeatureVector x)
	{
		return predictTwo(getPredictions(x));
	}
	
	/**
	 * Returns the first and second best predictions in the specific list.
	 * @param list the list of predictions; must contain at least 2 elements.
	 * @return the first and second best predictions in the specific list.
	 */
	public Pair<StringPrediction,StringPrediction> predictTwo(List<StringPrediction> list)
	{
		StringPrediction fst = list.get(0), snd = list.get(1), p;
		int i, size = list.size();
		
		// ensure fst holds the higher-scoring of the first two
		if (fst.score < snd.score)
		{
			fst = snd;
			snd = list.get(0);
		}
		
		for (i=2; i<size; i++)
		{
			p = list.get(i);
			
			if (fst.score < p.score)
			{
				snd = fst;
				fst = p;
			}
			else if (snd.score < p.score)
				snd = p;
		}
		
		return new Pair<StringPrediction,StringPrediction>(fst, snd);
	}
	
	/**
	 * Returns a sorted list of predictions given the specific feature vector.
	 * @param x the feature vector.
	 * @return a sorted list of predictions given the specific feature vector.
	 */
	public List<StringPrediction> predictAll(SparseFeatureVector x)
	{
		List<StringPrediction> list = getPredictions(x);
		Collections.sort(list);
		
		return list;
	}
	
	/**
	 * Returns an unsorted list of predictions given the specific feature vector.
	 * @param x the feature vector.
	 * @return an unsorted list of predictions given the specific feature vector.
	 */
	public List<StringPrediction> getPredictions(SparseFeatureVector x)
	{
		List<StringPrediction> list = new ArrayList<StringPrediction>(n_labels);
		double[] scores = getScores(x);
		int i;
		
		for (i=0; i<n_labels; i++)
			list.add(new StringPrediction(a_labels[i], scores[i]));
		
		return list;		
	}
	
	// NOTE(review): these should be "static final" constants, but they are left
	// non-final to preserve compatibility with any existing code that assigns them.
	/** The label representing {@code true} in boolean classification. */
	static public String LABEL_TRUE  = "T";
	/** The label representing {@code false} in boolean classification. */
	static public String LABEL_FALSE = "F";
	
	/**
	 * @param b the boolean value to convert.
	 * @return {@link #LABEL_TRUE} if {@code b} is {@code true}; otherwise, {@link #LABEL_FALSE}.
	 */
	static public String getBooleanLabel(boolean b)
	{
		return b ? LABEL_TRUE : LABEL_FALSE;
	}
	
	/**
	 * @param label the label to convert.
	 * @return {@code true} if the specific label equals {@link #LABEL_TRUE}.
	 */
	static public boolean toBoolean(String label)
	{
		return label.equals(LABEL_TRUE);
	}
}
